{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Automatic Speech Recognition combined with Speaker Diarization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "You can run either this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.\n",
    "\n",
    "Instructions for setting up Colab are as follows:\n",
    "1. Open a new Python 3 notebook.\n",
    "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n",
    "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n",
    "4. Run this cell to set up dependencies.\n",
    "\"\"\"\n",
    "# If you're using Google Colab and not running locally, run this cell.\n",
    "\n",
    "## Install dependencies\n",
    "!pip install wget\n",
    "!apt-get install sox libsndfile1 ffmpeg\n",
    "!pip install unidecode\n",
    "\n",
    "# ## Install NeMo\n",
    "BRANCH = 'r1.0.0rc1'\n",
    "!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[asr]\n",
    "\n",
    "## Install TorchAudio\n",
    "!pip install torchaudio -f https://download.pytorch.org/whl/torch_stable.html"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Introduction"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the early years, speaker diarization algorithms were developed for speech recognition on multispeaker audio recordings to enable speaker adaptive processing, but also gained its own value as a stand-alone application over\n",
    "time to provide speaker-specific meta information for downstream tasks such as audio retrieval.\n",
    "Automatic Speech Recognition output when combined with Speaker labels has shown immense use in many tasks, ranging from analyzing telephonic conversation to decoding meeting transcriptions. \n",
    "\n",
    "In this tutorial we demonstrate how one can get ASR transcriptions combined with Speaker labels along with voice activity time stamps using NeMo asr collections. \n",
    "\n",
    "For detailed understanding of transcribing words with ASR refer to this [ASR tutorial](https://github.com/NVIDIA/NeMo/blob/main/tutorials/asr/01_ASR_with_NeMo.ipynb), and for detailed understanding of speaker diarizing an audio refer to this [Diarization inference](https://github.com/NVIDIA/NeMo/blob/main/tutorials/speaker_recognition/Speaker_Diarization_Inference.ipynb) tutorial"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's first import nemo asr and other libraries for visualization purposes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import nemo.collections.asr as nemo_asr\n",
    "import numpy as np\n",
    "from IPython.display import Audio, display\n",
    "import librosa\n",
    "import os\n",
    "import wget\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We demonstrate this tutorial using merged an4 audio, that has two speakers(male and female) speaking dates in different formats. If not exists already download the data and listen to it"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ROOT = os.getcwd()\n",
    "data_dir = os.path.join(ROOT,'data')\n",
    "os.makedirs(data_dir, exist_ok=True)\n",
    "an4_audio_url = \"https://nemo-public.s3.us-east-2.amazonaws.com/an4_diarize_test.wav\"\n",
    "if not os.path.exists(os.path.join(data_dir,'an4_diarize_test.wav')):\n",
    "    AUDIO_FILENAME = wget.download(an4_audio_url, data_dir)\n",
    "else:\n",
    "    AUDIO_FILENAME = os.path.join(data_dir,'an4_diarize_test.wav')\n",
    "\n",
    "signal, sample_rate = librosa.load(AUDIO_FILENAME, sr=None)\n",
    "display(Audio(signal,rate=sample_rate))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_figure(signal,text='Audio',overlay_color=[]):\n",
    "    fig,ax = plt.subplots(1,1)\n",
    "    fig.set_figwidth(20)\n",
    "    fig.set_figheight(2)\n",
    "    plt.scatter(np.arange(len(signal)),signal,s=1,marker='o',c='k')\n",
    "    if len(overlay_color):\n",
    "        plt.scatter(np.arange(len(signal)),signal,s=1,marker='o',c=overlay_color)\n",
    "    fig.suptitle(text, fontsize=16)\n",
    "    plt.xlabel('time (secs)', fontsize=18)\n",
    "    plt.ylabel('signal strength', fontsize=14);\n",
    "    plt.axis([0,len(signal),-0.5,+0.5])\n",
    "    time_axis,_ = plt.xticks();\n",
    "    plt.xticks(time_axis[:-1],time_axis[:-1]/sample_rate);\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "plot the audio "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_figure(signal)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We start our demonstration by first transcribing the audio using our pretrained model `QuartzNet15x5Base-En` and use the CTC output probabilities to get timestamps for words spoken. We then later use these timestamps to get speaker label information using speaker diarizer model. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Download and load pretrained quartznet asr model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Load model\n",
    "asr_model = nemo_asr.models.EncDecCTCModel.from_pretrained(model_name='QuartzNet15x5Base-En', strict=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Transcribe the audio "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "files = [AUDIO_FILENAME]\n",
    "transcript = asr_model.transcribe(paths2audio_files=files)[0]\n",
    "print(f'Transcript: \"{transcript}\"')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Get CTC log probabilities with output labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# softmax implementation in NumPy\n",
    "def softmax(logits):\n",
    "    e = np.exp(logits - np.max(logits))\n",
    "    return e / e.sum(axis=-1).reshape([logits.shape[0], 1])\n",
    "\n",
    "# let's do inference once again but without decoder\n",
    "logits = asr_model.transcribe(files, logprobs=True)[0]\n",
    "probs = softmax(logits)\n",
    "\n",
    "# 20ms is duration of a timestep at output of the model\n",
    "time_stride = 0.02\n",
    "\n",
    "# get model's alphabet\n",
    "labels = list(asr_model.decoder.vocabulary) + ['blank']\n",
    "labels[0] = 'space'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use CTC labels for voice activity detection. To detect speech and non-speech segments in the audio, we use blank and space labels in the CTC outputs. Consecutive labels with spaces or blanks longer than a threshold are considered non-speech segments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "blanks = []\n",
    "\n",
    "state = ''\n",
    "idx_state = 0\n",
    "\n",
    "if np.argmax(probs[0]) == 28:\n",
    "    state = 'blank'\n",
    "\n",
    "for idx in range(1, probs.shape[0]):\n",
    "    current_char_idx = np.argmax(probs[idx])\n",
    "    if state == 'blank' and current_char_idx != 0 and current_char_idx != 28:\n",
    "        blanks.append([idx_state, idx-1])\n",
    "        state = ''\n",
    "    if state == '':\n",
    "        if current_char_idx == 28:\n",
    "            state = 'blank'\n",
    "            idx_state = idx\n",
    "\n",
    "if state == 'blank':\n",
    "    blanks.append([idx_state, len(probs)-1])\n",
    "\n",
    "threshold=20 #minimun width to consider non-speech activity \n",
    "non_speech=list(filter(lambda x:x[1]-x[0]>threshold,blanks)) \n",
    "\n",
    "# get timestamps for space symbols\n",
    "spaces = []\n",
    "\n",
    "state = ''\n",
    "idx_state = 0\n",
    "\n",
    "if np.argmax(probs[0]) == 0:\n",
    "    state = 'space'\n",
    "\n",
    "for idx in range(1, probs.shape[0]):\n",
    "    current_char_idx = np.argmax(probs[idx])\n",
    "    if state == 'space' and current_char_idx != 0 and current_char_idx != 28:\n",
    "        spaces.append([idx_state, idx-1])\n",
    "        state = ''\n",
    "    if state == '':\n",
    "        if current_char_idx == 0:\n",
    "            state = 'space'\n",
    "            idx_state = idx\n",
    "\n",
    "if state == 'space':\n",
    "    spaces.append([idx_state, len(pred)-1])\n",
    "# calibration offset for timestamps: 180 ms\n",
    "offset = -0.18\n",
    "\n",
    "# split the transcript into words\n",
    "words = transcript.split()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Frame level stamps for non speech frames "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(non_speech)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "write to rttm type file for later use in extracting speaker labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "frame_offset=offset/time_stride\n",
    "speech_labels=[]\n",
    "uniq_id = os.path.basename(AUDIO_FILENAME).split('.')[0]\n",
    "with open(uniq_id+'.rttm','w') as f:\n",
    "    for idx in range(len(non_speech)-1):\n",
    "        start = (non_speech[idx][1]+frame_offset)*time_stride\n",
    "        end = (non_speech[idx+1][0]+frame_offset)*time_stride\n",
    "        f.write(\"SPEAKER {} 1 {:.3f} {:.3f} <NA> <NA> speech <NA>\\n\".format(uniq_id,start,end-start))\n",
    "        speech_labels.append(\"{:.3f} {:.3f} speech\".format(start,end))\n",
    "    if non_speech[-1][1] < len(probs):\n",
    "        start = (non_speech[-1][1]+frame_offset)*time_stride\n",
    "        end = (len(probs)+frame_offset)*time_stride\n",
    "        f.write(\"SPEAKER {} 1 {:.3f} {:.3f} <NA> <NA> speech <NA>\\n\".format(uniq_id,start,end-start))\n",
    "        speech_labels.append(\"{:.3f} {:.3f} speech\".format(start,end))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Time stamps for speech frames"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(speech_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "COLORS=\"b g c m y\".split()\n",
    "def get_color(signal,speech_labels,sample_rate=16000):\n",
    "    c=np.array(['k']*len(signal))\n",
    "    for time_stamp in speech_labels:\n",
    "        start,end,label=time_stamp.split()\n",
    "        start,end = int(float(start)*16000),int(float(end)*16000),\n",
    "        if label == \"speech\":\n",
    "            code = 'red'\n",
    "        else:\n",
    "            code = COLORS[int(label.split('_')[-1])]\n",
    "        c[start:end]=code\n",
    "    \n",
    "    return c "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With voice activity time stamps extracted from CTC outputs, here we show the Voice Activity signal in <span style=\"color:red\">**red**</span> color and background speech in **black** color"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "color=get_color(signal,speech_labels)\n",
    "show_figure(signal,'an4 audio signal with vad',color)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use helper function from speaker utils to convert voice activity rttm file to manifest to diarize using \n",
    "speaker diarizer clustering inference model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nemo.collections.asr.parts.speaker_utils import write_rttm2manifest\n",
    "output_dir = os.path.join(ROOT, 'oracle_vad')\n",
    "os.makedirs(output_dir,exist_ok=True)\n",
    "oracle_manifest = os.path.join(output_dir,'oracle_manifest.json')\n",
    "write_rttm2manifest(paths2audio_files=files,\n",
    "                    paths2rttm_files=[uniq_id+'.rttm'],\n",
    "                    manifest_file=oracle_manifest)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!cat {output_dir}/oracle_manifest.json"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set up diarizer model "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from omegaconf import OmegaConf\n",
    "MODEL_CONFIG = os.path.join(data_dir,'speaker_diarization.yaml')\n",
    "if not os.path.exists(MODEL_CONFIG):\n",
    "    config_url = \"https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/speaker_recognition/conf/speaker_diarization.yaml\"\n",
    "    MODEL_CONFIG = wget.download(config_url,data_dir)   \n",
    "    \n",
    "config = OmegaConf.load(MODEL_CONFIG)\n",
    "\n",
    "pretrained_speaker_model='speakerdiarization_speakernet'\n",
    "config.diarizer.paths2audio_files = files\n",
    "config.diarizer.out_dir = output_dir #Directory to store intermediate files and prediction outputs\n",
    "\n",
    "config.diarizer.speaker_embeddings.model_path = pretrained_speaker_model\n",
    "# Ignoring vad we just need to pass the manifest file we created\n",
    "config.diarizer.speaker_embeddings.oracle_vad_manifest = oracle_manifest"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Diarize the audio at provided time stamps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nemo.collections.asr.models import ClusteringDiarizer\n",
    "oracle_model = ClusteringDiarizer(cfg=config);\n",
    "oracle_model.diarize();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nemo.collections.asr.parts.speaker_utils import rttm_to_labels\n",
    "pred_rttm=os.path.join(output_dir,'pred_rttms',uniq_id+'.rttm')\n",
    "labels=rttm_to_labels(pred_rttm)\n",
    "print(\"speaker labels with time stamps\\n\",labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let us see the audio plot color coded per speaker"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "color=get_color(signal,labels)\n",
    "show_figure(signal,'audio with speaker labels',color)\n",
    "display(Audio(signal,rate=16000))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally transcribe audio with time stamps and speaker label information"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pos_prev = 0\n",
    "idx=0\n",
    "start_point,end_point,speaker=labels[idx].split()\n",
    "print(\"{} [{:.2f} - {:.2f} sec]\".format(speaker,float(start_point),float(end_point)),end=\" \")\n",
    "for j, spot in enumerate(spaces):\n",
    "    pos_end = offset + (spot[0]+spot[1])/2*time_stride\n",
    "    if pos_prev < float(end_point):\n",
    "        print(words[j],end=\" \")\n",
    "    else:\n",
    "        print()\n",
    "        idx+=1\n",
    "        start_point,end_point,speaker=labels[idx].split()\n",
    "        print(\"{} [{:.2f} - {:.2f} sec]\".format(speaker,float(start_point),float(end_point)),end=\" \")\n",
    "        print(words[j],end=\" \")\n",
    "    pos_prev = pos_end\n",
    "\n",
    "\n",
    "print(words[j+1],end=\" \")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}